{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Modify this block to specify your paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations -- fill these in before running the cells below.\n",
    "question_pt = '' # path to val_questions.pt (pickle with 'questions', 'questions_len', 'image_idxs', 'answers')\n",
    "feature_h5 = '' # path to trainval_feature.h5 (read for 'ids', 'features', 'boxes', 'widths', 'heights')\n",
    "vocab_json = '' # path to vocab.json (question / answer / program token_to_idx tables)\n",
    "ckpt = '' # path to checkpoint model.pt (provides 'model_kwargs' and 'state_dict')\n",
    "image_dir = '' # path to mscoco/val2014, containing all mscoco val images\n",
    "ann_file = '' # path to mscoco/annotations/instances_val2014.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data and checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import matplotlib.pyplot as plt\n",
    "import pylab\n",
    "import os,sys\n",
    "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[1])))) # to import shared utils\n",
    "import torch\n",
    "import numpy as np\n",
    "import pickle\n",
    "import json\n",
    "import h5py\n",
    "import copy\n",
    "from PIL import Image, ImageFont, ImageDraw, ImageEnhance\n",
    "from utils.misc import todevice\n",
    "from model.net import XNMNet\n",
    "from model.composite_modules import MODULE_INPUT_NUM\n",
    "# Names of the available modules, in the key order of MODULE_INPUT_NUM.\n",
    "module_names = [name for name in MODULE_INPUT_NUM]\n",
    "print(module_names)\n",
    "\n",
    "# COCO annotation index built from ann_file; used later to look up image records by id.\n",
    "coco = COCO(ann_file)\n",
    "\n",
    "def invert_dict(d):\n",
    "    return {v: k for k, v in d.items()}\n",
    "\n",
    "def expand_batch(*args):\n",
    "    return (t.unsqueeze(0) for t in args)\n",
    "\n",
    "# Vocabulary: read the token->index tables, then derive the inverse index->token tables.\n",
    "with open(vocab_json, 'r') as f:\n",
    "    vocab = json.load(f)\n",
    "vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx'])\n",
    "vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])\n",
    "vocab['program_idx_to_token'] = invert_dict(vocab['program_token_to_idx'])\n",
    "\n",
    "# Questions: NOTE pickle.load can execute arbitrary code -- only open files you trust.\n",
    "with open(question_pt, 'rb') as f:\n",
    "    obj = pickle.load(f)\n",
    "questions = torch.LongTensor(np.asarray(obj['questions']))\n",
    "questions_len = torch.LongTensor(np.asarray(obj['questions_len']))\n",
    "q_image_indices = obj['image_idxs'] # coco_id\n",
    "answers = obj['answers']\n",
    "\n",
    "# Feature file: map each stored coco id to its position in the 'ids' dataset.\n",
    "with h5py.File(feature_h5, 'r') as f:\n",
    "    coco_ids = f['ids'][()]\n",
    "feat_coco_id_to_index = dict((cid, pos) for pos, cid in enumerate(coco_ids))\n",
    "\n",
    "\n",
    "def getitem(index):\n",
    "    \"\"\"Assemble the model inputs for the question stored at position `index`.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(coco_id, boxes, w, h, a, q, q_len, vision_feat, relation_mask)`.\n",
    "        The last four items are tensors with a leading batch dimension of size 1;\n",
    "        `relation_mask[i, j]` is 1 when boxes i and j (i != j) overlap, else 0.\n",
    "    \"\"\"\n",
    "    q = questions[index]\n",
    "    q_len = questions_len[index]\n",
    "    a = answers[index]\n",
    "    coco_id = q_image_indices[index]\n",
    "    feat_index = feat_coco_id_to_index[coco_id] # position of this image in the feature file\n",
    "    with h5py.File(feature_h5, 'r') as f:\n",
    "        vision_feat = f['features'][feat_index]\n",
    "        boxes = f['boxes'][feat_index]\n",
    "        w = f['widths'][feat_index]\n",
    "        h = f['heights'][feat_index]\n",
    "    vision_feat = torch.from_numpy(vision_feat).float()\n",
    "    # rows 0..3 of boxes are treated as x1, y1, x2, y2, with one column per box\n",
    "    x1, y1, x2, y2 = boxes[0], boxes[1], boxes[2], boxes[3]\n",
    "    # apart_x[i, j]: box i starts past the end of box j along x (likewise apart_y along y)\n",
    "    apart_x = x1[:, None] > x2[None, :]\n",
    "    apart_y = y1[:, None] > y2[None, :]\n",
    "    # two boxes do not overlap if either one lies entirely past the other on some axis\n",
    "    disjoint = apart_x | apart_x.T | apart_y | apart_y.T\n",
    "    relation_mask = (~disjoint).astype(np.uint8)\n",
    "    np.fill_diagonal(relation_mask, 0) # a box is not related to itself\n",
    "    relation_mask = torch.from_numpy(relation_mask)\n",
    "    q, q_len, vision_feat, relation_mask = expand_batch(q, q_len, vision_feat, relation_mask)\n",
    "    return coco_id, boxes, w, h, a, q, q_len, vision_feat, relation_mask\n",
    "\n",
    "\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Deserialize every tensor onto the CPU, whichever device it was saved from;\n",
    "# load_state_dict below copies the weights into the model that lives on `device`.\n",
    "loaded = torch.load(ckpt, map_location='cpu')\n",
    "model_kwargs = loaded['model_kwargs']\n",
    "model_kwargs.update({'vocab': vocab, 'device': device})\n",
    "model = XNMNet(**model_kwargs).to(device)\n",
    "model.load_state_dict(loaded['state_dict'])\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the model on one question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 54 # change this index to get different inputs\n",
    "\n",
    "coco_id, boxes, img_width, img_height, answer, *batch = getitem(index)\n",
    "batch = [todevice(x, device) for x in batch]\n",
    "img_obj = coco.loadImgs([coco_id])[0]\n",
    "img = io.imread(os.path.join(image_dir, img_obj['file_name']))\n",
    "# print(img.shape, img_height, img_width)\n",
    "pylab.rcParams['figure.figsize'] = (8.0, 10.0)\n",
    "plt.axis('off')\n",
    "plt.imshow(img)\n",
    "q, q_len = batch[:2]\n",
    "# q_len comes back from getitem with a batch dimension of 1; convert it once to a plain int for slicing\n",
    "num_tokens = int(q_len)\n",
    "q_str = ' '.join([vocab['question_idx_to_token'][i.item()] for i in q[0, :num_tokens]])\n",
    "print(q_str)\n",
    "\n",
    "# inference only: no autograd graph is needed for the forward pass\n",
    "with torch.no_grad():\n",
    "    logits, others = model(*batch)\n",
    "predict = torch.max(logits.squeeze(), dim=0)[1].item()\n",
    "# reference answer = the most frequent answer index among the annotations\n",
    "answer = np.argmax(np.bincount(answer))\n",
    "print(\"predict: %s; answer: %s\" % (vocab['answer_idx_to_token'][predict], vocab['answer_idx_to_token'][answer]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the attention of every module at each step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def drawrect(drawcontext, xy, outline=None, width=0):\n",
    "    x1, y1, x2, y2 = xy\n",
    "    points = (x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)\n",
    "    drawcontext.line(points, fill=outline, width=width)\n",
    "\n",
    "def plot_attention(img, boxes, att):\n",
    "    \"\"\"Fade the image towards white outside attended boxes and outline the top box in red.\n",
    "\n",
    "    Args:\n",
    "        img: image array of shape (H, W, C), or (H, W) for a grayscale image.\n",
    "        boxes: array whose rows 0..3 are used as x1, y1, x2, y2, one column per box.\n",
    "        att: one attention weight per box, used as a blend factor\n",
    "            (1 keeps the original pixel, 0 turns it white).\n",
    "\n",
    "    Returns:\n",
    "        A new image array with the overlay and a red rectangle around the box with\n",
    "        the largest weight. The input image is not modified.\n",
    "    \"\"\"\n",
    "    att = np.asarray(att, dtype=np.float64)\n",
    "    img = np.asarray(img)\n",
    "    if img.ndim == 2:\n",
    "        # grayscale image: replicate the single channel so it can be blended and drawn on in colour\n",
    "        img = np.stack([img] * 3, axis=-1)\n",
    "    height, width = img.shape[:2]\n",
    "    # for every pixel, the largest weight among the boxes that cover it\n",
    "    pixel_peak = np.zeros((height, width))\n",
    "    for k in range(min(len(att), boxes.shape[1])):\n",
    "        # truncate to ints and clamp at 0; slicing clips the upper bounds to the image size\n",
    "        x1, y1, x2, y2 = (max(int(boxes[r][k]), 0) for r in range(4))\n",
    "        region = pixel_peak[y1:y2, x1:x2] # a view, so the update below is in place\n",
    "        np.maximum(region, att[k], out=region)\n",
    "    weight = pixel_peak[:, :, None]\n",
    "    blended = 255 * (1 - weight) + img * weight\n",
    "    red_box = boxes[:4, np.argmax(att)]\n",
    "    canvas = Image.fromarray(np.uint8(blended))\n",
    "    draw = ImageDraw.Draw(canvas)\n",
    "    drawrect(draw, red_box, outline='red', width=4)\n",
    "    return np.asarray(canvas)\n",
    "\n",
    "\n",
    "num_module = len(module_names)\n",
    "pylab.rcParams['figure.figsize'] = (10.0, 20.0)\n",
    "\n",
    "# For each step t in range(T_ctrl): print the per-token values from others['cv'],\n",
    "# then, for every module, its probability and an image of its attention over the boxes.\n",
    "for t in range(model_kwargs['T_ctrl']):\n",
    "    print(t)\n",
    "    token_weights = others['cv'][t][0, :q_len].tolist()\n",
    "    input_att = list(zip(q_str.split(), token_weights))\n",
    "    print(input_att)\n",
    "    for m, module_name in enumerate(module_names):\n",
    "        module_prob = others['module_prob'][t][m][0]\n",
    "        print(\"%s: %.3f\" % (module_name, module_prob))\n",
    "        a = others['att'][t][m][0][:,0].tolist()\n",
    "        # plot_attention gets a copy so the original image stays untouched\n",
    "        plot_img = plot_attention(copy.copy(img), boxes, a)\n",
    "        plt.axis('off')\n",
    "        plt.imshow(plot_img)\n",
    "        plt.pause(0.0001)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
